import os
import inspect
os.environ["CUDA_VISIBLE_DEVICES"] = "2" 
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from transformers import PatchTSTModel, PatchTSTConfig, TrainingArguments, EarlyStoppingCallback, Trainer, PatchTSTForPrediction, PatchTSMixerConfig, PatchTSMixerForPrediction
#from reservoir_computing.modules import RC_model
from configuration import ReservoirTConfig
from tqdm import tqdm
from datasets import Dataset
from neuralforecast.models import DLinear
import wandb
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)



# Start a Weights & Biases run. This executes at import time (there is no
# __main__ guard here), so merely importing this module opens a run.
# Reservoir_fl_model.forward logs its losses to this run via wandb.log.
wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",

    # track hyperparameters and run metadata
    # NOTE(review): these values are metadata only; they mirror
    # TrainingArguments(learning_rate=0.002, num_train_epochs=100) further down
    # and must be kept in sync by hand.
    config={
    "learning_rate": 0.002,
    "epochs": 100,
    }
)

# Shared configuration object read by every model class in this file.
configuration = ReservoirTConfig()

# Output width of each reservoir's W_outputs projection in Reservoir_fl_model.
# Those projections become the keys/values of cross-attention layers whose
# kdim/vdim is embedding_size, so this must equal embedding_size.
configuration.output_size=140
# Forecast horizon / input window used to build DLinear; they must match
# pred_len / sequence_length used when slicing the data into windows.
configuration.prediction_length = 720
configuration.context_length = 336
# Embedding dim of DeepReservoirNet.self_attention (constructed but not called
# in its forward); must be divisible by num_attention_heads.
configuration.re_output_size=21
# Not read by any code visible in this file.
configuration.max_sequence_length=1000
# Window length for create_sequences and the input width of each reservoir's W_in.
configuration.sequence_length=336
# Horizon used by create_sequences to build the target windows.
configuration.pred_len=720
# Number of input channels (features per time step). It is used as the channel
# count in TimeSeriesEmbedding, the reservoir state rows and the cross-attention
# embed_dim. Presumably the 7 ETTh1 columns left after dropping "date" -- confirm.
configuration.hidden_size=7
# Heads for DeepReservoirNet.self_attention.
configuration.num_attention_heads=7
# The next two settings are not read by any code visible in this file.
configuration.hidden_dropout_prob=0.1
configuration.num_hidden_layers=16
# Number of reservoirs; each per-reservoir list below needs this many entries.
configuration.num_reservoirs = 10
# Not read by any code visible in this file.
configuration.intermediate_size=128
# Per-reservoir hyper-parameters, indexed by reservoir number and passed to
# DeepReservoirNet.
configuration.reservoir_size = [10, 11,12, 13, 14, 15, 16,  17, 18, 19]
configuration.spectral_radius = [0.6, 0.8, 0.55, 0.6, 0.5, 0.4, 0.3, 0.2, 0.81, 0.05]
configuration.sparsity = [0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15]
configuration.leaky = [0.4, 0.41, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49]
#configuration.reservoir_size = 1000
# Not read by any code visible in this file.
configuration.attention_probs_dropout_prob=0.0
# Stored on TimeSeriesEmbedding / Reservoir_fl_model; the DataLoaders in this
# file do not read it.
configuration.batch_size = 16
# Key/value width of the cross-attention layers and TimeSeriesEmbedding's output width.
configuration.embedding_size = 140
# TimeSeriesEmbedding mode (1 = q/k/v attention, 2 = per-feature linear tokens).
configuration.embedding_type = 2
# Heads of TimeSeriesEmbedding's attention; embedding_size must be divisible by it.
configuration.num_heads = 7


class TimeSeriesEmbedding(nn.Module):
    """Embed a multivariate time series into a wider feature space.

    The mode is chosen by ``config.embedding_type``:

    * ``1`` -- project the input with separate query/key/value linear layers
      (``hidden_size -> embedding_size``) and run multi-head attention. Keys and
      values come from ``key_values_input_ids`` when given, otherwise from the
      input itself.
    * ``2`` -- "feature as token": feature ``i`` (last axis of the input) gets
      its own ``nn.Linear(1, embedding_size // hidden_size)``. Each per-feature
      embedding is multiplied by the raw feature value, and the results are
      concatenated along the last axis.
    """

    def __init__(self, config):
        super().__init__()
        # Number of input features, i.e. the size of the input's last axis.
        self.hidden_size = config.hidden_size
        self.embedding_size = config.embedding_size
        self.embedding_type = config.embedding_type
        self.sequence_length = config.sequence_length
        # Width of each feature's slice in mode 2.
        self.feature_as_token_each_feature_emb_size = config.embedding_size // config.hidden_size
        # Modules are created in the original order so seeded initialisation is unchanged.
        self.query = nn.Linear(self.hidden_size, self.embedding_size)
        self.key = nn.Linear(self.hidden_size, self.embedding_size)
        self.value = nn.Linear(self.hidden_size, self.embedding_size)
        self.multihead_attn = nn.MultiheadAttention(self.embedding_size, num_heads=config.num_heads, batch_first=True)
        self.batch_size = config.batch_size
        self.feature_as_token_weights = nn.ModuleList(
            [nn.Linear(1, self.feature_as_token_each_feature_emb_size) for _ in range(self.hidden_size)]
        )

    def forward(self, input_ids, key_values_input_ids=None):
        """Embed ``input_ids``, indexed ``(batch, time, hidden_size)``.

        Args:
            input_ids: input series; it is cast to float32.
            key_values_input_ids: optional key/value source for mode 1. It has the
                same last-axis size as ``input_ids`` and is cast to float32.

        Returns:
            Mode 1: the attention output, ``(batch, time, embedding_size)``.
            Mode 2: ``(batch, time, hidden_size * (embedding_size // hidden_size))``.

        Raises:
            ValueError: if ``embedding_type`` is neither 1 nor 2.
        """
        input_ids = input_ids.float()

        if self.embedding_type == 1:
            query = self.query(input_ids)
            kv_source = input_ids if key_values_input_ids is None else key_values_input_ids.float()
            key = self.key(kv_source)
            # Values use their own projection; reusing self.key made keys and
            # values identical and left self.value untrained.
            value = self.value(kv_source)
            attn_output, _ = self.multihead_attn(query, key, value)
            return attn_output

        if self.embedding_type == 2:
            per_feature = []
            for i, projection in enumerate(self.feature_as_token_weights):
                feature = input_ids[:, :, i].unsqueeze(-1)  # (batch, time, 1)
                # The embedding is scaled by the raw feature value (broadcast over the last axis).
                per_feature.append(projection(feature) * feature)
            return torch.cat(per_feature, dim=-1)

        raise ValueError(f"Unsupported embedding_type: {self.embedding_type!r} (expected 1 or 2)")
        

class DeepReservoirNet(nn.Module):
    """A single leaky echo-state reservoir with fixed (untrained) random weights.

    ``forward`` advances the reservoir once per entry of the input's leading
    axis and carries the state from one entry to the next.

    Args:
        config: must provide ``sequence_length`` (input width of ``W_in``),
            ``re_output_size`` and ``num_attention_heads`` (for the unused
            ``self_attention`` module).
        reservoir_size: number of reservoir units.
        spectral_radius: target spectral radius of the recurrent matrix ``W_res``.
        leaky: leak rate ``a`` in ``s <- (1 - a) * s + a * tanh(W_in x + W_res s)``.
        sparsity: fraction of ``W_res`` entries forced to zero.
    """

    def __init__(self, config, reservoir_size=1000, spectral_radius=0.9, leaky=0.3, sparsity=0.5):
        super(DeepReservoirNet, self).__init__()

        self.input_size = config.sequence_length
        self.reservoir_size = reservoir_size
        self.output_size = config.re_output_size
        self.spectral_radius = spectral_radius
        self.leaky = leaky

        # Fixed random input and recurrent weights, excluded from training.
        self.W_in = nn.Linear(self.input_size, reservoir_size, bias=False).float()
        self.W_in.weight.requires_grad = False
        self.W_res = nn.Linear(reservoir_size, reservoir_size, bias=False).float()
        self.W_res.weight.requires_grad = False
        # Not read by forward(); callers pass the initial state explicitly.
        self.res_state = torch.zeros(1, reservoir_size).float()

        self.act = nn.Tanh()

        # Redraws W_res. The returned (pre-scaling) radius also divides the state
        # at every step in forward().
        self.W_res_norm = self.compute_spectral_radius(sparsity)
        # Not called in forward(); kept so the module's parameters/state_dict keys are unchanged.
        self.self_attention = nn.MultiheadAttention(self.output_size, config.num_attention_heads, dropout=0.2)

    def compute_spectral_radius(self, sparsity=0.5):
        """Redraw ``W_res`` as a sparse Gaussian matrix with spectral radius ``self.spectral_radius``.

        Args:
            sparsity: fraction of entries set to zero.

        Returns:
            The spectral radius of the sparse random matrix *before* rescaling
            (a 0-dim tensor). If that radius is 0 (e.g. ``sparsity=1``), the
            matrix is left unscaled and ``1.0`` is returned, so ``forward``
            does not divide by zero.
        """
        n = self.reservoir_size
        with torch.no_grad():
            weight = torch.randn(n, n)
            # Zero out a random fraction of the entries.
            num_zeros = int(sparsity * n ** 2)
            zero_idx = torch.randperm(n ** 2)[:num_zeros]
            weight.view(-1)[zero_idx] = 0

            radius = torch.linalg.eigvals(weight).abs().max()
            if radius > 0:
                # Rescale to the requested spectral radius (previously the matrix
                # was always normalised to radius 1 and spectral_radius was ignored).
                weight *= self.spectral_radius / radius
            else:
                radius = torch.ones((), dtype=weight.dtype)
            self.W_res.weight.data = weight
        return radius

    def forward(self, input_data, res_state):
        """Advance the reservoir over the leading axis of ``input_data``.

        Args:
            input_data: tensor indexed ``(steps, sequence_length, channels)``.
                It is permuted to ``(steps, channels, sequence_length)`` so that
                ``W_in`` maps each channel's window to ``reservoir_size`` units.
            res_state: initial state, broadcast-compatible with
                ``(channels, reservoir_size)``.

        Returns:
            dict with ``'Output'``, the per-step states stacked on a new dim 0
            (each after ``squeeze(0)``), and ``'State'``, the final state.
        """
        steps = input_data.permute(0, 2, 1)
        outputs = []
        for step_input in steps:
            input_proj = self.W_in(step_input.float())
            res_proj = self.W_res(res_state)
            res_state = (1 - self.leaky) * res_state + self.leaky * torch.tanh(input_proj + res_proj)
            # Shrink the state by the raw radius of W_res computed at construction.
            res_state = res_state / self.W_res_norm
            outputs.append(res_state.squeeze(0))
        return {'Output': torch.stack(outputs, dim=0), "State": res_state}
    


class ReservoirTTimeSeries(nn.Module):
    """Run a bank of fixed echo-state reservoirs over a whole split of windows.

    The whole split is passed in at once and its leading axis is treated as
    time, so each reservoir's state flows from one sample to the next. For the
    evaluation/test splits, the earlier splits are prepended to warm the
    reservoirs up, and only the rows of the requested split are returned.
    """

    def __init__(self, config):
        super().__init__()
        self.num_labels = config.num_labels
        self.config = config

        # Not used in forward().
        self.layer_norm = nn.LayerNorm(config.hidden_size)

        self.reservoirs = nn.ModuleList()
        self.id_train = None
        self.id_test = None
        self.reservoir_state = None
        self.state_ids = None

        for i in range(config.num_reservoirs):
            reservoir = DeepReservoirNet(config=config,
                                         reservoir_size=config.reservoir_size[i],
                                         spectral_radius=config.spectral_radius[i],
                                         leaky=config.leaky[i],
                                         sparsity=config.sparsity[i])
            self.reservoirs.append(reservoir)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        reservoir_ids: Optional[torch.Tensor] = None,
        state_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        dataset_type = None,
        train_dataset = None,
        eval_dataset = None,
        id = "id_train",
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""Compute reservoir features for one data split.

        Args:
            inputs_embeds: the split's samples stacked on dim 0. Presumably
                ``(samples, sequence_length, channels)`` -- TODO confirm for
                other callers.
            reservoir_ids: reservoir input. It defaults to ``inputs_embeds``
                shifted by one sample (a zero sample in front, the last sample
                dropped), so output ``k`` only depends on samples ``< k``.
            labels_ids: returned unchanged.
            dataset_type: ``None`` for the training split; ``"eval_dataset"``
                prepends ``train_dataset["inputs_embeds"]``; ``"test_dataset"``
                prepends train and ``eval_dataset["inputs_embeds"]``.
            train_dataset, eval_dataset: dicts holding an ``"inputs_embeds"`` tensor.

        The remaining Hugging-Face-style arguments (``input_ids``, ``state_ids``,
        ``attention_mask``, ..., ``id``) are accepted for signature
        compatibility and ignored.

        Returns:
            dict with ``"inputs_embeds"``, ``"reservoir_outputs"`` (one list per
            sample holding one tensor per reservoir) and ``"labels_ids"``.

        Raises:
            ValueError: if ``dataset_type`` is not one of the values above.
        """
        # Fail fast instead of running every reservoir and then returning None.
        if dataset_type not in (None, "eval_dataset", "test_dataset"):
            raise ValueError(f"Unknown dataset_type: {dataset_type!r}")

        sample_size = None
        if dataset_type == "eval_dataset":
            sample_size = inputs_embeds.shape[0]
            inputs_embeds = torch.cat((train_dataset["inputs_embeds"], inputs_embeds), dim=0)
        elif dataset_type == "test_dataset":
            sample_size = inputs_embeds.shape[0]
            inputs_embeds = torch.cat((train_dataset["inputs_embeds"],
                                       eval_dataset["inputs_embeds"],
                                       inputs_embeds), dim=0)

        if reservoir_ids is None:
            # F.pad pairs run from the last dim backwards: (last, second-to-last,
            # first). This prepends one all-zero sample on dim 0; [:-1] then
            # drops the final sample.
            reservoir_ids = F.pad(inputs_embeds, (0, 0, 0, 0, 1, 0))[:-1]

        device = inputs_embeds.device
        reservoir_input = reservoir_ids.float()
        per_reservoir = []
        for i, reservoir in tqdm(enumerate(self.reservoirs)):
            # Every reservoir starts from a zero state of shape (channels, reservoir_size).
            initial_state = torch.zeros(self.config.hidden_size, self.config.reservoir_size[i]).float().to(device)
            per_reservoir.append(reservoir(reservoir_input, initial_state)['Output'])

        # Regroup per-reservoir tensors into per-sample lists:
        # reservoir_outputs[k][r] is sample k's output from reservoir r.
        reservoir_outputs = [list(per_sample) for per_sample in zip(*per_reservoir)]

        if sample_size is None:
            return {"inputs_embeds": inputs_embeds,
                    "reservoir_outputs": reservoir_outputs,
                    "labels_ids": labels_ids}

        # Keep only the requested split's rows. Explicit start indices avoid the
        # x[-0:] pitfall, which would return everything for an empty split.
        return {"inputs_embeds": inputs_embeds[inputs_embeds.shape[0] - sample_size:],
                "reservoir_outputs": reservoir_outputs[len(reservoir_outputs) - sample_size:],
                "labels_ids": labels_ids}


import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn

class moving_avg(nn.Module):
    """Moving average used to extract the trend component of a series.

    The input is edge-padded by repeating its first and last time steps. With
    ``stride=1`` the output has the same length as the input for both odd and
    even ``kernel_size``. An even kernel gets one extra step of padding at the
    end.
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.avg = nn.AvgPool1d(kernel_size, stride=stride, padding=0)
        self.kernel_size = kernel_size

    def forward(self, x):
        """Smooth ``x``, laid out ``(batch, time, channels)``, along the time axis."""
        # Total padding is kernel_size - 1. The two halves are equal for odd
        # kernels, which matches the previous behaviour. For even kernels the
        # old symmetric (k-1)//2 padding made the output one step short, and
        # series_decomp's ``x - trend`` then failed.
        front_len = (self.kernel_size - 1) // 2
        end_len = self.kernel_size // 2
        front = x[:, 0:1, :].repeat(1, front_len, 1)
        end = x[:, -1:, :].repeat(1, end_len, 1)
        x_padded = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last axis, so move time there and back.
        return self.avg(x_padded.permute(0, 2, 1)).permute(0, 2, 1)

class series_decomp(nn.Module):
    """Time-series decomposition into a seasonal (residual) part and a trend part.

    ``x`` is laid out ``(batch, time, channels)``. The trend is a
    ``kernel_size`` moving average along time, and ``seasonal = x - trend``.
    """

    def __init__(self, kernel_size):
        # Initialise nn.Module before assigning any attribute on it;
        # nn.Module.__setattr__ relies on state created by __init__.
        super().__init__()
        self.kernel_size = kernel_size
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(seasonal, trend)`` for ``x``."""
        trend = self.moving_avg(x)
        seasonal = x - trend
        return seasonal, trend

class DLinear(nn.Module):
    """Decomposition-linear forecaster.

    The input ``(batch, seq_len, channels)`` is split into seasonal and trend
    parts (moving-average kernel 25). Each part is mapped ``seq_len -> pred_len``
    by a linear layer over the time axis, and the two results are summed.

    Args:
        seq_len: input window length.
        pred_len: forecast horizon.
        enc_in: number of channels; only used when ``individual`` is true.
        individual: give every channel its own seasonal/trend layers instead of
            sharing one pair across all channels.

    The output layout depends on the mode:
        * ``individual=False``: ``(batch, pred_len, channels)``.
        * ``individual=True``: the per-channel results are stacked on a new
          leading axis, giving ``(channels, pred_len, batch)``. Callers must
          permute accordingly.
    """

    def __init__(self, seq_len, pred_len, enc_in, individual=False):
        super().__init__()
        self.decomp = series_decomp(kernel_size=25)
        self.individual = individual

        if not individual:
            # One seasonal and one trend layer, shared by all channels.
            self.Linear_Seasonal = nn.Linear(seq_len, pred_len)
            self.Linear_Trend = nn.Linear(seq_len, pred_len)
        else:
            # A separate pair of layers for every channel.
            self.Linear_Seasonal = nn.ModuleList([nn.Linear(seq_len, pred_len) for _ in range(enc_in)])
            self.Linear_Trend = nn.ModuleList([nn.Linear(seq_len, pred_len) for _ in range(enc_in)])

    @staticmethod
    def _apply_per_channel(layers, series):
        """Apply ``layers[c]`` to channel ``c`` of ``series`` (batch, channels, time); stack on a new dim 0."""
        return torch.stack([layer(series[:, channel, :]) for channel, layer in enumerate(layers)])

    def forward(self, x):
        seasonal_part, trend_part = self.decomp(x)
        # (batch, time, channels) -> (batch, channels, time) so nn.Linear acts on time.
        seasonal_part = seasonal_part.permute(0, 2, 1)
        trend_part = trend_part.permute(0, 2, 1)

        if not self.individual:
            combined = self.Linear_Seasonal(seasonal_part) + self.Linear_Trend(trend_part)
        else:
            combined = (self._apply_per_channel(self.Linear_Seasonal, seasonal_part)
                        + self._apply_per_channel(self.Linear_Trend, trend_part))

        return combined.permute(0, 2, 1)

class Reservoir_fl_model(nn.Module):
    """Forecaster combining a DLinear baseline with cross-attention to reservoir features.

    ``forward`` proceeds in three steps:

    1. Project each reservoir's features with its own ``W_outputs`` layer and
       concatenate them.
    2. Forecast with DLinear.
    3. Refine the forecast with two post-norm blocks of (cross-attention to the
       reservoir features, then feed-forward).

    ``EmbeddingModel``, ``ReservoirModel`` and ``decoder`` are constructed, and
    so are part of the parameters/state dict, but are not used in ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        # Keep the original configuration fields.
        self.num_labels = config.num_labels
        self.config = config
        self.hidden_size = config.hidden_size
        self.re_output_size = config.re_output_size
        self.sequence_length = config.sequence_length
        self.batch_size = config.batch_size
        self.num_res = config.num_reservoirs

        # Simplified DLinear with one layer pair per input channel. The channel
        # count is the number of input features (config.hidden_size) rather
        # than a hard-coded 7.
        self.dlinear = DLinear(seq_len=self.config.context_length,
                               pred_len=self.config.prediction_length,
                               enc_in=self.config.hidden_size,
                               individual=True)

        # Remaining components (creation order preserved so seeded init is unchanged).
        self.EmbeddingModel = TimeSeriesEmbedding(self.config)
        self.ReservoirModel = ReservoirTTimeSeries(self.config)
        self.crossattn = nn.MultiheadAttention(
            embed_dim=self.config.hidden_size,
            kdim=self.config.embedding_size,
            vdim=self.config.embedding_size,
            num_heads=7,
            batch_first=True,
            dropout=0.2
        )
        self.crossattn1 = nn.MultiheadAttention(
            embed_dim=self.config.hidden_size,
            kdim=self.config.embedding_size,
            vdim=self.config.embedding_size,
            num_heads=7,
            batch_first=True,
            dropout=0.2
        )
        self.decoder = nn.Linear(
            int(self.config.embedding_size/self.config.hidden_size)*self.config.hidden_size,
            self.config.hidden_size
        )
        # One projection per reservoir: reservoir_size[i] -> output_size.
        self.W_outputs = nn.ModuleList([
            nn.Linear(config.reservoir_size[i], self.config.output_size).float()
            for i in range(self.num_res)
        ])

        self.norm1 = nn.LayerNorm(self.config.hidden_size)
        self.norm2 = nn.LayerNorm(self.config.hidden_size)
        self.norm3 = nn.LayerNorm(self.config.hidden_size)
        self.norm4 = nn.LayerNorm(self.config.hidden_size)
        # Feed-forward layers of the two refinement blocks.
        self.feed_forward = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.hidden_size),
            nn.ReLU(),
            nn.Linear(self.config.hidden_size, self.config.hidden_size)
        )

        self.feed_forward1 = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.hidden_size),
            nn.ReLU(),
            nn.Linear(self.config.hidden_size, self.config.hidden_size)
        )
        self.dropout = nn.Dropout(0.3)

    def forward(self, inputs_embeds, reservoir_outputs, labels_ids=None):
        """Forecast from input windows and precomputed reservoir features.

        Args:
            inputs_embeds: input windows. DLinear treats dim 1 as time and the
                last dim as channels, i.e. ``(batch, context_length, hidden_size)``.
            reservoir_outputs: sequence of ``num_reservoirs`` tensors, the
                ``i``-th ending in ``reservoir_size[i]``. Presumably each is
                ``(batch, hidden_size, reservoir_size[i])`` after default
                collation -- confirm.
            labels_ids: optional targets with the prediction's shape.

        Returns:
            dict with ``"loss"`` (MSE, or None without labels),
            ``"prediction_outputs"`` and ``"cros_attn_weights"`` (from the
            second cross-attention).
        """
        # Project every reservoir to output_size and concatenate along dim 1:
        # these are the cross-attention keys/values.
        projected = [W_out(output) for output, W_out in zip(reservoir_outputs, self.W_outputs)]
        memory = torch.cat(projected, dim=1).float()

        prediction = self.dlinear(inputs_embeds.float())
        # With individual=True, DLinear returns (channels, pred_len, batch);
        # reorder to (batch, pred_len, channels).
        prediction = prediction.permute(2, 1, 0).float()

        # Block 1: cross-attention to reservoir features + feed-forward.
        # NOTE(review): dropout is applied to the residual input rather than to
        # the attention output (the usual post-norm layout); kept to preserve
        # trained behaviour.
        prediction_crs, attn_weight = self.crossattn(prediction, memory, memory)
        hidden = self.norm1(prediction_crs + self.dropout(prediction))
        prediction = self.norm2(hidden + self.dropout(self.feed_forward(hidden)))

        # Block 2: same structure with separate weights.
        prediction_crs, attn_weight = self.crossattn1(prediction, memory, memory)
        hidden = self.norm3(prediction_crs + self.dropout(prediction))
        prediction = self.norm4(hidden + self.dropout(self.feed_forward1(hidden)))

        loss = None
        if labels_ids is not None:
            target = labels_ids.float()
            loss = F.mse_loss(prediction, target)
            # MAE is only logged, so it does not need an autograd graph.
            with torch.no_grad():
                mae_loss = F.l1_loss(prediction, target)
            wandb.log({"mse_loss": loss, "mae_loss": mae_loss})

        return {
            "loss": loss,
            "prediction_outputs": prediction,
            "cros_attn_weights": attn_weight
        }

def extract_inputs_and_labels(dataset, batch_size=16, desc="Training Batches"):
    """Materialise a map-style dataset into two stacked tensors.

    Iterates ``dataset`` in order (no shuffling) through a DataLoader and
    concatenates the ``"inputs_embeds"`` and ``"labels_ids"`` fields of every
    batch along dim 0.

    Args:
        dataset: dataset whose items are dicts with ``"inputs_embeds"`` and
            ``"labels_ids"`` tensors.
        batch_size: loader batch size. With default collation it affects speed
            only, not the result.
        desc: progress-bar label.

    Returns:
        dict with keys ``"inputs_embeds"`` and ``"labels_ids"``.

    Raises:
        ValueError: if ``dataset`` yields no batches (``torch.cat`` cannot
            concatenate an empty list).
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    input_batches = []
    label_batches = []
    for batch in tqdm(loader, desc=desc):
        input_batches.append(batch['inputs_embeds'])
        label_batches.append(batch['labels_ids'])

    if not input_batches:
        raise ValueError("extract_inputs_and_labels received an empty dataset")

    return {"inputs_embeds": torch.cat(input_batches, dim=0),
            "labels_ids": torch.cat(label_batches, dim=0)}


#from time_data_normalize import Dataset_ETT_hour
import numpy as np
# prepare data for lstm
from sklearn.preprocessing import StandardScaler
from pandas import read_csv
from pandas import DataFrame
import random
from sklearn.model_selection import train_test_split
from pandas import concat
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset

if __name__ == "__main__":

    # Load ETTh1, drop the timestamp and standardise every remaining column
    # (including the target "OT").
    dataset = read_csv('ETTh1.csv')
    dataset = dataset.dropna()
    dataset = dataset.drop(['date'], axis=1)

    # NOTE(review): the scaler is fit on the full series, so validation/test
    # statistics leak into the normalisation. The usual ETT protocol fits on
    # the training span only.
    scaler = StandardScaler()
    X = scaler.fit_transform(dataset.values)

    from tqdm.auto import tqdm

    def create_sequences(data, seq_length, pred_length):
        """Slice ``data`` (time, features) into (input window, following horizon) pairs.

        Window ``i`` covers rows ``[i, i + seq_length)`` and its target covers the
        next ``pred_length`` rows.

        Returns:
            Two tensors with ``data``'s dtype, shaped
            ``(windows, seq_length, features)`` and ``(windows, pred_length, features)``.
        """
        sequences = []
        targets = []
        for i in tqdm(range(len(data) - seq_length - pred_length + 1)):
            sequences.append(data[i:i + seq_length])
            targets.append(data[i + seq_length:i + seq_length + pred_length])
        # Stack in NumPy first: torch.tensor() on a list of ndarrays is
        # extremely slow (and emits a warning).
        return torch.from_numpy(np.array(sequences)), torch.from_numpy(np.array(targets))

    X, y = create_sequences(X, seq_length=configuration.sequence_length, pred_length=configuration.pred_len)

    # Split the window indices into chunks of `batch` consecutive windows,
    # dropping the tail so that every chunk is full.
    batch = 100
    indices = np.arange(len(X))
    barrier = int(len(indices) / batch) * batch
    indices = indices[0:barrier]
    # Number of chunks by which val/test start before the previous split ends.
    # Presumably this gives the reservoirs a warm-up overlap -- confirm.
    soft_border = int((configuration.sequence_length / batch)) + 8

    indices = [indices[i:i + batch] for i in range(0, len(indices), batch)]

    border1 = int(len(indices) * 0.9)
    border2 = border1 + int(len(indices) * 0.1)
    # NOTE(review): 0.9 + 0.1 already covers (almost) every chunk, so border3
    # runs past the end. The test split is then just the last chunks starting
    # soft_border before border2, and it largely overlaps the validation split.
    border3 = border2 + int(len(indices) * 0.2)

    train_ind = indices[0:border1]
    val_ind = indices[border1 - soft_border: border2]
    test_ind = indices[border2 - soft_border: border3]

    # Flatten the chunk lists back into per-window samples.
    X_train = [X[item] for sublist in train_ind for item in sublist]
    y_train = [y[item] for sublist in train_ind for item in sublist]

    X_val = [X[item] for sublist in val_ind for item in sublist]
    y_val = [y[item] for sublist in val_ind for item in sublist]

    X_test = [X[item] for sublist in test_ind for item in sublist]
    y_test = [y[item] for sublist in test_ind for item in sublist]
#train_indices, test_indices =train_test_split(indices,  test_size=0.2, shuffle=False)
#indices = [item for sublist in indices for item in sublist]

import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Map-style dataset of (input window, optional target) pairs as float32 tensors.

    Args:
        tokenized_inputs: indexable of per-sample inputs (tensors, arrays or
            nested lists).
        labels: optional indexable of per-sample targets aligned with
            ``tokenized_inputs``. When given, items also carry ``"labels_ids"``.
        pos: stored on the instance but not used by this class.
    """

    def __init__(self, tokenized_inputs, labels=None, pos=None):
        self.tokenized_inputs = tokenized_inputs
        self.labels = labels
        self.pos = pos
        # Placeholders kept for compatibility; not read by __getitem__.
        self.id_list = None
        self.re = None

    def __len__(self):
        return len(self.tokenized_inputs)

    @staticmethod
    def _as_float_tensor(value):
        """Return a float32 tensor copy of ``value`` that shares no memory with it."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(tensor) also copies but emits a UserWarning on every
            # call; the inputs built in this script are tensors.
            return value.detach().to(dtype=torch.float32, copy=True)
        return torch.tensor(value).float()

    def __getitem__(self, idx):
        item = {"inputs_embeds": self._as_float_tensor(self.tokenized_inputs[idx])}
        if self.labels is not None:
            item["labels_ids"] = self._as_float_tensor(self.labels[idx])
        return item

# Assuming you have X_train, y_train, X_test, y_test, trainpos, and testpos defined


if __name__ == "__main__":
    # Wrap the windowed splits into datasets yielding float32 tensors.
    # print(X_train[0], flush=True)
    train_dataset = CustomDataset(X_train, y_train)
    # print(train_dataset[0], flush=True)

    val_dataset = CustomDataset(X_val, y_val)
    
    # NOTE(review): test_dataset is built but not used in this block.
    test_dataset = CustomDataset(X_test, y_test)

    # Offline feature extractor. This is a separately initialised instance:
    # Reservoir_fl_model constructs its own ReservoirTTimeSeries (ReservoirModel),
    # which is not the one that produces these features.
    preprocess = ReservoirTTimeSeries(configuration)
    train_dataset_dic = extract_inputs_and_labels(train_dataset)
    #val_dataset_dic = extract_inputs_and_labels(val_dataset)
    # NOTE(review): despite the name, this is the *validation* split; every
    # test_* object below holds validation data and becomes the Trainer's eval set.
    test_dataset_dic = extract_inputs_and_labels(val_dataset)
    #preprocess(inputs_embeds = train_dataset_dic["inputs_embeds"],labels_ids = train_dataset_dic["labels_ids"])

    # Rebinds the module-global name `Dataset` to Hugging Face's datasets.Dataset.
    from datasets import Dataset
    train_dataset_fl = Dataset.from_dict(preprocess(inputs_embeds = train_dataset_dic["inputs_embeds"],
                                                        labels_ids = train_dataset_dic["labels_ids"]))
    train_dataset_fl.set_format(type='torch')
    # "eval_dataset" mode replays the training inputs through the reservoirs
    # first and keeps only the validation rows.
    # NOTE(review): val_ind starts soft_border chunks before the end of train_ind,
    # so those windows appear twice in the replayed sequence -- confirm this is intended.
    test_dataset_fl = Dataset.from_dict(preprocess(inputs_embeds = test_dataset_dic["inputs_embeds"],
                                                        labels_ids = test_dataset_dic["labels_ids"],
                                                        dataset_type = "eval_dataset",
                                                        train_dataset = train_dataset_dic))
    test_dataset_fl.set_format(type='torch')
    #print("train_dataset_fl input_embs",train_dataset_fl["inputs_embeds"][0].shape)
    #print("train_dataset_fl reservoir_outputs",train_dataset_fl["reservoir_outputs"][0].shape)
    #print("train_dataset_fl input_embs",train_dataset_fl["inputs_embeds"][0].shape)
    #print("train_dataset_fl reservoir_outputs",train_dataset_fl["reservoir_outputs"][0].shape)


#embedding_model = TimeSeriesEmbedding(configuration)
#reservoir_model = ReservoirTTimeSeries(configuration)
#fl_model = Reservoir_fl_model(configuration)
#dataloader = DataLoader(train_dataset,batch_size=64,shuffle = False)

#for batch in dataloader:
#    inputs_embeds = batch["inputs_embeds"]
#    label_ids = batch["labels_ids"]
#    inputs_embeds = embedding_model(inputs_embeds)
#    #print(inputs_embeds.shape)
#    #reservoir_output,reservoir_state = reservoir_model(inputs_embeds = inputs_embeds)
#    result = fl_model(inputs_embeds = inputs_embeds)
#    break
            
# Hugging Face Trainer configuration. It is built at import time, outside any
# __main__ guard. learning_rate / num_train_epochs mirror the wandb.init config.
# NOTE(review): ReservoirTrainer builds its own DataLoaders; check whether the
# per_device_*_batch_size values below are actually honoured.
training_args = TrainingArguments(
    output_dir="./checkpoint/patchtst/ETTh1/pretrain/last_hope_output/",
    overwrite_output_dir=True,
    learning_rate=0.002,
    num_train_epochs=100,
    do_eval=True,
    eval_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=3,
    logging_dir="./checkpoint/patchtst/ETTh1/pretrain/logs/",  # Make sure to specify a logging directory
    load_best_model_at_end=True,  # Load the best model when training ends
    metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
    greater_is_better=False,  # For loss
    # Batch key holding the targets (matches the "labels_ids" keys produced above).
    label_names=["labels_ids"],
)

# Early stopping on eval_loss.
# NOTE(review): this callback is not passed to the trainer (the callbacks=
# argument is commented out), so training always runs all epochs.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=20,  # Number of epochs with no improvement after which to stop
    early_stopping_threshold=0.001,  # Minimum improvement required to consider as improvement
)
#print(train_dataset[0])
#print(train_dataset[0].keys())
#print(train_dataset[0])
#print(train_dataset[0].keys())



class ReservoirTrainer(Trainer):
    """Trainer whose dataloaders are plain ``torch.utils.data.DataLoader`` objects.

    Batches use PyTorch's default collation instead of the Trainer's data
    collator. Batch sizes come from ``TrainingArguments``; training data is
    shuffled, while evaluation/test data is iterated in dataset order.
    """

    def get_train_dataloader(self) -> DataLoader:
        """Shuffled loader over ``self.train_dataset``."""
        return DataLoader(self.train_dataset,
                          shuffle=True,
                          batch_size=self.args.per_device_train_batch_size)

    def get_eval_dataloader(self, eval_dataset=None) -> DataLoader:
        """Ordered loader over ``eval_dataset`` (default: ``self.eval_dataset``)."""
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        # Deterministic order keeps gathered predictions aligned with dataset rows.
        return DataLoader(eval_dataset,
                          shuffle=False,
                          batch_size=self.args.per_device_eval_batch_size)

    def get_test_dataloader(self, test_dataset=None) -> DataLoader:
        """Ordered loader over ``test_dataset``.

        Raises:
            ValueError: if no dataset is given (Trainer keeps no default test set).
        """
        if test_dataset is None:
            raise ValueError("get_test_dataloader requires a test_dataset")
        return DataLoader(test_dataset,
                          shuffle=False,
                          batch_size=self.args.per_device_eval_batch_size)

# Built at import time (outside the __main__ guard), so importing this module
# constructs the full model, including its own, unused ReservoirTTimeSeries.
model  = Reservoir_fl_model(configuration)
if __name__ == "__main__":
    trainer = ReservoirTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset_fl,
        # Built from val_dataset in the preceding __main__ block.
        eval_dataset=test_dataset_fl,
        # callbacks=[early_stopping_callback],
        # compute_metrics=compute_metrics,
    )
#print(type(train_dataset_fl["labels_ids"]))
# pretrain
    trainer.train()
# Training loop

#res_state = torch.zeros(1, 1000)
#for batch in train_loader:
#    inputs_embeds = batch['inputs_embeds']  # Extract input sequences from the batch
#    labels_ids = batch['labels_ids']        # Extract target sequences from the batch#

    # Forward pass through the DeepReservoirNet
#    reservoir_outputs = model(inputs_embeds=inputs_embeds)

    # Get the model's outputs and updated reservoir state
    #outputs = output_dict['Output']
    #res_state = output_dict['State']
#    print(reservoir_outputs.shape) #the output shape is (batch_size,output_size,num_features)
#    print(reservoir_outputs) 
#    break  #next step is to keep track of all Reservoir states across all batches
           #next step is to use the cross attetnion to combine input and reservoir_outputs

# Step 4: Forward pass through the model
#output_dict = model(train_dataset)
#print(res_state.shape)